Re-weighting#

PolicyEngine-UK primarily relies on the Family Resources Survey, which has known issues with non-capture of households at the bottom and top of the income distribution. To correct for this, we apply a weight modification, optimised using gradient descent to minimise survey error against a diverse selection of target statistics. These include:

  • Regional populations

  • Household populations

  • Population by tenure type

  • Population by Council Tax band

  • Country-level program statistics

  • UK-wide program aggregates

  • UK-wide program caseloads

The graph below shows the effect of the optimisation on each of these, compared to their starting values (under original FRS weights). All loss subfunctions improve from their starting values.

Hide code cell source
# Load the per-epoch training log and plot, for each target category, the
# relative change in total loss from its epoch-0 value (i.e. under the
# original FRS weights). Values below zero mean the category improved.
import numpy as np
import pandas as pd
import plotly.express as px

df = pd.read_csv(
    "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/no_val_split/training_log_run_1.csv.gz",
)
# Total loss per (category, epoch), reshaped to one column per category.
ldf = (
    df.groupby(["category", "epoch"])
    .sum()
    .reset_index()
    .pivot(columns="category", values="loss", index="epoch")
)
# Express each category's loss relative to its starting value:
# 0 = unchanged, -0.5 = halved.
ldf /= ldf.loc[0]
ldf -= 1
ldf = ldf.reset_index().melt(id_vars=["epoch"])

# Human-readable hover text describing the direction and size of the change.
ldf["hover"] = [
    f"At epoch {epoch}, the total loss from targets <br>in the category <b>{category}</b> <br>has <b>{'risen' if value > 0 else 'fallen'}</b> by <b>{abs(value):.1%}</b>."
    for epoch, category, value in zip(ldf.epoch, ldf.category, ldf.value)
]

px.line(
    ldf, x="epoch", y="value", color="category", custom_data=[ldf.hover]
).update_traces(hovertemplate="%{customdata[0]}").update_layout(
    title="Training performance by category",
    height=600,
    width=800,
    xaxis_title="Epoch",
    yaxis_title="Loss change",
    legend_title="Category",
    yaxis_range=(-1, 0),
    yaxis_tickformat=".0%",
)

Changes to distributions#

Validation#

During initial training, we split the targets into training and validation groups (80%/20%), performing 5-fold cross-validation. The graph below shows the performance of validation metrics in each fold, as well as the average over the five folds.

Hide code cell source
# Plot relative loss change over training for the validation targets, the
# training targets, and both combined, across the five cross-validation folds.
df = pd.read_csv(
    "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/train_val_split/training_log.csv.gz",
    compression="gzip",
)
xdf = pd.DataFrame()
# Maps each selection mode to the label shown in the animation frames.
type_labels = {
    True: "Validation",
    False: "Training",
    "Both": "Training + Validation",
}
for validation_type, label in type_labels.items():
    # Select the validation rows, the training rows, or every row ("Both").
    if validation_type == "Both":
        mask = df.validation | ~df.validation
    else:
        mask = df.validation == validation_type
    # One column of summed loss per fold (run_id), indexed by epoch.
    fold_losses = (
        df[mask]
        .groupby(["run_id", "epoch"])
        .loss.sum()
        .reset_index()
        .pivot(columns="run_id", values="loss", index="epoch")
    )
    # Relative change from epoch 0: 0 = unchanged, -0.5 = halved.
    fold_losses = fold_losses / fold_losses.loc[0] - 1
    fold_losses = fold_losses.dropna()
    fold_losses["Average"] = fold_losses.mean(axis=1)
    fold_losses["Type"] = label
    xdf = pd.concat([xdf, fold_losses])
px.line(
    xdf,
    y=xdf.columns,
    animation_frame="Type",
    # Five light-grey fold lines plus a darker line for the average.
    color_discrete_sequence=["lightgrey"] * 5 + ["grey"],
).update_layout(
    title="5-fold cross-validation training",
    yaxis_title="Relative loss change",
    yaxis_tickformat=".0%",
    xaxis_title="Epoch",
    legend_title="Fold",
    width=800,
    height=800,
)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[2], line 1
----> 1 df = pd.read_csv(
      2     "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/train_val_split/training_log.csv.gz",
      3     compression="gzip",
      4 )
      5 xdf = pd.DataFrame()
      6 for validation_type in (True, False, "Both"):

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:912, in read_csv(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)
    899 kwds_defaults = _refine_defaults_read(
    900     dialect,
    901     delimiter,
   (...)
    908     dtype_backend=dtype_backend,
    909 )
    910 kwds.update(kwds_defaults)
--> 912 return _read(filepath_or_buffer, kwds)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:577, in _read(filepath_or_buffer, kwds)
    574 _validate_names(kwds.get("names", None))
    576 # Create the parser.
--> 577 parser = TextFileReader(filepath_or_buffer, **kwds)
    579 if chunksize or iterator:
    580     return parser

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:1407, in TextFileReader.__init__(self, f, engine, **kwds)
   1404     self.options["has_index_names"] = kwds["has_index_names"]
   1406 self.handles: IOHandles | None = None
-> 1407 self._engine = self._make_engine(f, self.engine)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:1661, in TextFileReader._make_engine(self, f, engine)
   1659     if "b" not in mode:
   1660         mode += "b"
-> 1661 self.handles = get_handle(
   1662     f,
   1663     mode,
   1664     encoding=self.options.get("encoding", None),
   1665     compression=self.options.get("compression", None),
   1666     memory_map=self.options.get("memory_map", False),
   1667     is_text=is_text,
   1668     errors=self.options.get("encoding_errors", "strict"),
   1669     storage_options=self.options.get("storage_options", None),
   1670 )
   1671 assert self.handles is not None
   1672 f = self.handles.handle

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/pandas/io/common.py:716, in get_handle(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)
    713     codecs.lookup_error(errors)
    715 # open URLs
--> 716 ioargs = _get_filepath_or_buffer(
    717     path_or_buf,
    718     encoding=encoding,
    719     compression=compression,
    720     mode=mode,
    721     storage_options=storage_options,
    722 )
    724 handle = ioargs.filepath_or_buffer
    725 handles: list[BaseBuffer]

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/site-packages/pandas/io/common.py:373, in _get_filepath_or_buffer(filepath_or_buffer, encoding, compression, mode, storage_options)
    370         if content_encoding == "gzip":
    371             # Override compression based on Content-Encoding header
    372             compression = {"method": "gzip"}
--> 373         reader = BytesIO(req.read())
    374     return IOArgs(
    375         filepath_or_buffer=reader,
    376         encoding=encoding,
   (...)
    379         mode=fsspec_mode,
    380     )
    382 if is_fsspec_url(filepath_or_buffer):

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/http/client.py:476, in HTTPResponse.read(self, amt)
    474 else:
    475     try:
--> 476         s = self._safe_read(self.length)
    477     except IncompleteRead:
    478         self._close_conn()

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/http/client.py:626, in HTTPResponse._safe_read(self, amt)
    624 s = []
    625 while amt > 0:
--> 626     chunk = self.fp.read(min(amt, MAXAMOUNT))
    627     if not chunk:
    628         raise IncompleteRead(b''.join(s), amt)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/socket.py:704, in SocketIO.readinto(self, b)
    702 while True:
    703     try:
--> 704         return self._sock.recv_into(b)
    705     except timeout:
    706         self._timeout_occurred = True

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/ssl.py:1242, in SSLSocket.recv_into(self, buffer, nbytes, flags)
   1238     if flags != 0:
   1239         raise ValueError(
   1240           "non-zero flags not allowed in calls to recv_into() on %s" %
   1241           self.__class__)
-> 1242     return self.read(nbytes, buffer)
   1243 else:
   1244     return super().recv_into(buffer, nbytes, flags)

File /opt/hostedtoolcache/Python/3.9.16/x64/lib/python3.9/ssl.py:1100, in SSLSocket.read(self, len, buffer)
   1098 try:
   1099     if buffer is not None:
-> 1100         return self._sslobj.read(len, buffer)
   1101     else:
   1102         return self._sslobj.read(len)

KeyboardInterrupt: 

The below chart visualises the effect of the training process on each individual training and validation metric, by epoch.

Hide code cell source
# Animated scatter of per-target relative error (prediction / actual - 1)
# by epoch, restricted to the budgetary-impact and UK-wide aggregate
# categories, and exported as a GIF.
df["rel_error"] = df.pred / df.actual - 1
df["Type"] = np.where(df.validation, "Validation", "Training")
STEP_SIZE = 50  # one animation frame every 50 epochs

cdf = df[df.epoch % STEP_SIZE == 0]
cdf = cdf[
    (cdf.category == "Budgetary impact")
    | (cdf.category == "UK-wide program aggregates")
]

fig = px.scatter(
    cdf,
    animation_frame="epoch",
    x="actual",
    y="rel_error",
    color="Type",
    hover_data=df.columns,
    opacity=0.2,
)
layout = dict(
    title="Target metrics",
    width=800,
    height=800,
    legend_title="Type",
    yaxis_title="Relative error",
    yaxis_tickformat=".1%",
    xaxis_tickprefix="£",
    xaxis_title="Actual value",
    yaxis_range=(-1, 1),
)
fig.update_layout(**layout)

# Apply the shared layout to every frame; the per-frame title reflects the
# epoch. (Frames cover both plotted categories, not just budgetary impact,
# so the title must not name a single category.)
for i, frame in enumerate(fig.frames):
    frame.layout.update(layout)
    frame.layout[
        "title"
    ] = f"Target metric performance at {i * STEP_SIZE:,} epochs"

# Force a full redraw on slider moves and play-button presses so layout
# changes in each frame take effect.
for step in fig.layout.sliders[0].steps:
    step["args"][1]["frame"]["redraw"] = True

for button in fig.layout.updatemenus[0].buttons:
    button["args"][1]["frame"]["redraw"] = True

import gif
import plotly.graph_objects as go

# Render each animation frame as a static figure and save a ~3-second GIF.
gif.save(
    [
        gif.frame(lambda: go.Figure(data=frame.data, layout=frame.layout))()
        for frame in fig.frames
    ],
    "scatterplot.gif",
    duration=3_000 / len(fig.frames),
)

fig